/**
MIT License

Copyright (c) 2022 Augustusmyc
Copyright (c) 2023-2024 Joker2770

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/

#include "mcts.h"

#include <float.h>
#include <numeric>
#include <iostream>
#include <random>
#include <assert.h>
#include <algorithm>
#include <cmath>

// TreeNode
/// Build a detached node with no parent, no child slots, zero prior, zero
/// value estimate and no visits. Implemented by delegating to the
/// three-argument constructor with an action size of 0 (empty children).
TreeNode::TreeNode() : TreeNode(nullptr, 0., 0) {}

/// Build a fresh leaf node.
/// @param parent       parent node, or nullptr for a root
/// @param p_sa         prior of the action that leads to this node
///                     (expand() passes action_priors[i] here)
/// @param action_size  number of child slots; all start out empty (nullptr)
TreeNode::TreeNode(TreeNode *parent, double p_sa, unsigned int action_size)
    : parent(parent),
      children(action_size, nullptr),
      is_leaf(true),
      virtual_loss(0),
      n_visited(0),
      p_sa(p_sa),
      q_sa(0) {}

/// Copy constructor, written by hand because the std::atomic counters
/// (n_visited, virtual_loss) are not copyable; their current values are
/// loaded and used to initialise the new node. The mutex is not copied.
/// NOTE(review): the copy is shallow - `children` pointers are shared with
/// `node`, so freeing both trees via MCTS::tree_deleter would free the same
/// subtree twice.
TreeNode::TreeNode(const TreeNode &node)
    : parent(node.parent),
      children(node.children),
      is_leaf(node.is_leaf),
      virtual_loss(node.virtual_loss.load()),
      n_visited(node.n_visited.load()),
      p_sa(node.p_sa),
      q_sa(node.q_sa) {}

/// Copy assignment, written by hand because the std::atomic counters cannot
/// be copy-assigned; their values are loaded from `node` and stored here.
/// Shallow: `children` pointers are shared, not cloned. This node's mutex is
/// left untouched. Self-assignment is a no-op.
TreeNode &TreeNode::operator=(const TreeNode &node)
{
  if (this != &node)
  {
    this->parent = node.parent;
    this->children = node.children;
    this->is_leaf = node.is_leaf;
    this->p_sa = node.p_sa;
    this->q_sa = node.q_sa;

    // atomics: copy the current values
    this->n_visited.store(node.n_visited.load());
    this->virtual_loss.store(node.virtual_loss.load());
  }
  return *this;
}

// TODO：try random select?
unsigned int TreeNode::select(double c_puct, double c_virtual_loss)
{
  double best_value = -DBL_MAX;
  unsigned int best_move = 0;
  TreeNode *best_node = nullptr;

  for (unsigned int i = 0; i < this->children.size(); i++)
  {
    // empty node
    if (children[i] == nullptr)
    {
      continue;
    }

    unsigned int sum_n_visited = this->n_visited.load() + 1;
    double cur_value =
        children[i]->get_value(c_puct, c_virtual_loss, sum_n_visited);
    if (cur_value > best_value)
    {
      best_value = cur_value;
      best_move = i;
      best_node = children[i];
    }
  }

  // add vitural loss
  if (nullptr != best_node)
    best_node->virtual_loss++;

  return best_move;
}

/// Create a child for every action whose prior is at least FLT_EPSILON and
/// mark this node as no longer a leaf. The work is done under this node's
/// mutex; if the node is already expanded (e.g. by another simulation
/// thread) the call does nothing.
///
/// @param action_priors  prior per action index. Entries below FLT_EPSILON
///                       get no child. If the vector is shorter than the
///                       number of child slots, the missing entries are
///                       treated the same way instead of being read out of
///                       bounds.
void TreeNode::expand(const std::vector<double> &action_priors)
{
  std::lock_guard<std::mutex> guard(this->lock);

  if (!this->is_leaf)
  {
    return; // already expanded
  }

  const unsigned int action_size =
      static_cast<unsigned int>(this->children.size());
  // Never index past the end of a priors vector shorter than children.
  const std::size_t n_priors =
      std::min<std::size_t>(action_size, action_priors.size());

  for (std::size_t i = 0; i < n_priors; i++)
  {
    // (near-)zero prior: treated as an illegal action, no child created
    if (action_priors[i] < FLT_EPSILON)
    {
      continue;
    }
    this->children[i] = new TreeNode(this, action_priors[i], action_size);
  }

  // not leaf
  this->is_leaf = false;
}

// void TreeNode::expand(const std::vector<double>& action_priors), std::vector<int>& legal_moves) {
//     {
//         // get lock
//         std::lock_guard<std::mutex> lock(this->lock);
//
//         if (this->is_leaf) {
//             unsigned int action_size = this->children.size();
//
//
//             for (unsigned int i = 0; i < action_size; i++) {
//                 // illegal action
//                 if (legal_moves[i] < FLT_EPSILON)  {
//                     //std::cout << "illegal action " << i << " is: "<<action_priors[i] << std::endl;
//                     //std::cout << "illegal action " << i << " is: "<<abs(action_priors[i]) << std::endl;
//                     continue;
//                 }
//                 this->children[i] = new TreeNode(this, action_priors[i], action_size);
//                 //std::cout << "action_priors[i] = " <<i<< action_priors[i] << std::endl;
//             }
//
//             // not leaf
//             this->is_leaf = false;
//         }//else{std::cout << "why not leaf !?" << std::endl;}
//     }
// }

/// Propagate a simulation result from this node up to the root.
///
/// Ancestors are updated first and receive -value at each level up, so each
/// node's q_sa is stored from the perspective opposite to its children's.
/// Every call adds one visit to the node and folds `value` into the running
/// mean q_sa.
///
/// @param value  simulation result from this node's point of view
void TreeNode::backup(double value)
{
  // If it is not root, this node's parent should be updated first
  if (this->parent != nullptr)
  {
    this->parent->backup(-value);

    // Remove the virtual loss that select() added when it chose this node.
    // select() only ever adds virtual loss to children, never to the root,
    // so the root is skipped (decrementing it made the root's counter drift
    // further negative on every simulation).
    this->virtual_loss--;
  }

  // The visit-count increment and the q_sa update must be one step: with a
  // separate load() followed by ++, two threads could read the same old
  // count and each weight q_sa as if it were the only new sample.
  {
    std::lock_guard<std::mutex> guard(this->lock);
    // atomic post-increment returns the count before this visit
    const auto prev_visits = this->n_visited++;
    this->q_sa = (prev_visits * this->q_sa + value) / (prev_visits + 1);
  }
}

/// Selection score of this node as seen from its parent (used by select()):
///
///   u     = c_puct * p_sa * sqrt(sum_n_visited) / (1 + n_visited)
///   score = u                                               if n_visited == 0
///   score = u + (q_sa * n_visited - c_virtual_loss * virtual_loss) / n_visited
///
/// @param c_puct          exploration constant
/// @param c_virtual_loss  penalty applied per pending virtual loss
/// @param sum_n_visited   parent visit count (select() passes n_visited + 1)
/// @return the score; higher is preferred by select()
///
/// NOTE(review): for an unvisited node the virtual-loss penalty is not
/// applied, so concurrent simulations are not discouraged from choosing the
/// same unvisited child. Confirm this is intended.
double TreeNode::get_value(double c_puct, double c_virtual_loss,
                           unsigned int sum_n_visited) const
{
  const auto visits = this->n_visited.load();

  // exploration term: grows with parent visits, shrinks with own visits
  const double exploration = c_puct * this->p_sa *
                             std::sqrt(static_cast<double>(sum_n_visited)) /
                             (1 + visits);

  // no visits yet: no value estimate, rank on exploration alone
  if (visits <= 0)
  {
    return exploration;
  }

  // mean value, lowered by the virtual loss of in-flight simulations
  const double penalty = c_virtual_loss * this->virtual_loss.load();
  return exploration + (this->q_sa * visits - penalty) / visits;
}

// MCTS
/// Create a search with a fresh single-node tree.
///
/// @param neural_network  leaf evaluator; may be nullptr, in which case
///                        simulate() uses uniform priors over legal moves and
///                        a leaf value of 0
/// @param thread_num      number of worker threads in the simulation pool
/// @param c_puct          exploration constant passed to TreeNode::select()
/// @param num_mcts_sims   simulations submitted per get_action_probs() /
///                        get_best_action() call
/// @param c_virtual_loss  virtual-loss weight passed to TreeNode::select()
/// @param action_size     number of child slots per tree node
MCTS::MCTS(NeuralNetwork *neural_network, unsigned int thread_num, double c_puct,
           unsigned int num_mcts_sims, double c_virtual_loss,
           unsigned int action_size)
    : neural_network(neural_network),
      thread_pool(new ThreadPool(thread_num)),
      c_puct(c_puct),
      num_mcts_sims(num_mcts_sims),
      c_virtual_loss(c_virtual_loss),
      action_size(action_size),
      // NOTE(review): seeded from the wall clock (1 s resolution); two
      // instances built within the same second share a seed and therefore
      // the same random sequence.
      rnd_eng(static_cast<unsigned int>(std::time(nullptr))),
      // the whole tree is released through tree_deleter
      root(new TreeNode(nullptr, 1.0, action_size), MCTS::tree_deleter) {}

void MCTS::update_with_move(int last_action)
{
  auto old_root = this->root.get();

  // reuse the child tree
  if (last_action >= 0 && old_root->children[last_action] != nullptr)
  {
    // unlink
    TreeNode *new_node = old_root->children[last_action];
    old_root->children[last_action] = nullptr;
    new_node->parent = nullptr;

    this->root.reset(new_node);
  }
  else
  {
    this->root.reset(new TreeNode(nullptr, 1., this->action_size));
  }
}

/// Deleter used by the root pointer: frees `t` together with every node
/// below it, children before their parent. A null pointer is ignored.
void MCTS::tree_deleter(TreeNode *t)
{
  if (t != nullptr)
  {
    // empty child slots (nullptr) return immediately from the recursive call
    for (TreeNode *child : t->children)
    {
      tree_deleter(child);
    }

    delete t;
  }
}

/// Run num_mcts_sims simulations from the current root, each on its own copy
/// of `gomoku`, then turn the root children's visit counts into a move
/// distribution.
///
/// @param gomoku  current position; only copied, never modified
/// @param temp    temperature. temp <= 1e-3 (plus FLT_EPSILON) is greedy:
///                probability 1 on the most visited child (lowest index on
///                ties). Larger values give p(a) proportional to
///                n(a)^(1/temp).
/// @return gomoku->get_action_size() probabilities. In the non-greedy case
///         the vector is all zeros if no root child was visited.
std::vector<double> MCTS::get_action_probs(Gomoku *gomoku, double temp)
{
  // Any temperature at or below this is handled greedily. A range test
  // (rather than equality with 1e-3) keeps temp = 0 and other tiny values
  // from making 1/temp infinite and the probabilities NaN.
  constexpr double greedy_temp = 1e-3;

  // submit simulate tasks to thread_pool
  std::vector<std::future<void>> futures;
  futures.reserve(this->num_mcts_sims);

  for (unsigned int i = 0; i < this->num_mcts_sims; i++)
  {
    // each simulation plays on its own copy of the game
    auto game = std::make_shared<Gomoku>(*gomoku);
    auto future =
        this->thread_pool->commit(std::bind(&MCTS::simulate, this, game));

    // future can't copy
    futures.emplace_back(std::move(future));
  }
  // wait for all simulations before reading visit counts
  for (auto &future : futures)
  {
    future.wait();
  }

  std::vector<double> action_probs(gomoku->get_action_size(), 0);
  const auto &children = this->root->children;

  // Most visited child: the greedy answer, and the scale used below.
  unsigned int max_count = 0;
  unsigned int best_action = 0;
  for (unsigned int i = 0; i < children.size(); i++)
  {
    if (children[i] && children[i]->n_visited.load() > max_count)
    {
      max_count = children[i]->n_visited.load();
      best_action = i;
    }
  }

  // greedy
  if (temp < greedy_temp + FLT_EPSILON)
  {
    action_probs[best_action] = 1.;
    return action_probs;
  }

  // explore: with no visits there is nothing to normalise (avoid 0/0)
  if (max_count == 0)
  {
    return action_probs;
  }

  // Counts are divided by max_count before exponentiation. The ratios are
  // unchanged after renormalisation, but n^(1/temp) can no longer overflow
  // to inf (and then NaN) for small temperatures or large visit counts.
  double sum = 0;
  for (unsigned int i = 0; i < children.size(); i++)
  {
    if (children[i] && children[i]->n_visited.load() > 0)
    {
      const double ratio =
          static_cast<double>(children[i]->n_visited.load()) / max_count;
      action_probs[i] = std::pow(ratio, 1. / temp);
      sum += action_probs[i];
    }
  }

  // renormalization (sum >= 1: the most visited child contributes 1)
  for (double &p : action_probs)
  {
    p /= sum;
  }
  return action_probs;
}

/// Return the index of the largest entry of `action_probs` (the lowest index
/// on ties). Asserts that some entry exceeds -1, which holds for any
/// non-empty probability vector; with assertions disabled, -1 is returned
/// otherwise (e.g. for an empty vector).
int MCTS::get_best_action_from_prob(std::vector<double> &action_probs)
{
  int best_action = -1;
  double best_prob = -1.0;

  // Scan the vector's actual length rather than assuming
  // BOARD_SIZE * BOARD_SIZE entries, which reads out of bounds for shorter
  // vectors.
  for (std::size_t i = 0; i < action_probs.size(); i++)
  {
    if (action_probs[i] > best_prob)
    {
      best_prob = action_probs[i];
      best_action = static_cast<int>(i);
    }
  }
  assert(best_action >= 0);
  return best_action;
}

/// Run num_mcts_sims simulations from the current root, each on its own copy
/// of `gomoku`, and return the index of the most visited root child (the
/// lowest index on ties). Asserts that at least one child was visited; with
/// assertions disabled, -1 is returned in that case.
int MCTS::get_best_action(Gomoku *gomoku)
{
  std::vector<std::future<void>> pending;
  pending.reserve(this->num_mcts_sims);

  for (unsigned int sim = 0; sim < this->num_mcts_sims; ++sim)
  {
    // every simulation mutates its own snapshot of the position
    std::shared_ptr<Gomoku> snapshot = std::make_shared<Gomoku>(*gomoku);
    pending.emplace_back(
        this->thread_pool->commit(std::bind(&MCTS::simulate, this, snapshot)));
  }

  // block until every simulation has finished
  for (auto &task : pending)
  {
    task.wait();
  }

  // greedy pick over the root's children
  int best_action = -1;
  unsigned int most_visits = 0;
  const auto &kids = this->root->children;
  for (unsigned int idx = 0; idx < kids.size(); ++idx)
  {
    const TreeNode *kid = kids[idx];
    if (kid == nullptr)
    {
      continue;
    }

    const auto visits = kid->n_visited.load();
    if (visits > most_visits)
    {
      most_visits = visits;
      best_action = idx;
    }
  }

  assert(best_action >= 0);
  return best_action;
}

/// Draw an action index according to `probs`.
///
/// A number r is drawn from rnd_dis (presumably uniform on [0, 1) - confirm
/// against its declaration), and the first index whose cumulative
/// probability exceeds r is returned.
///
/// @param probs  per-action probabilities, expected to sum to about 1
/// @return the sampled index. If floating-point rounding leaves the
///         cumulative sum at or below r, the last index with a non-zero
///         probability is returned, not index 0 (which may be an illegal,
///         zero-probability move).
int MCTS::get_action_by_sample(std::vector<double> &probs)
{
  const double r = rnd_dis(rnd_eng);
  // never read past the end of a vector shorter than action_size
  const std::size_t n = std::min<std::size_t>(action_size, probs.size());

  double accum = 0.0;
  int last_nonzero = 0;
  for (std::size_t i = 0; i < n; i++)
  {
    if (probs[i] > 0.0)
    {
      last_nonzero = static_cast<int>(i);
    }
    accum += probs[i];
    if (accum > r)
    {
      return static_cast<int>(i);
    }
  }
  return last_nonzero;
}

void MCTS::simulate(std::shared_ptr<Gomoku> game)
{
  // execute one simulation
  auto node = this->root.get();

  while (true)
  {
    if (node->get_is_leaf())
    {
      // std::cout << "node is leaf, break" << std::endl;
      break;
    }

    // select
    auto action = node->select(this->c_puct, this->c_virtual_loss);
    game->execute_move(action);
    node = node->children[action];
  }

  // get game status
  auto status = game->get_game_status();
  double value = 0;

  // not end
  if (status.first == 0)
  {
    // predict action_probs and value by neural network
    std::vector<double> action_priors(this->action_size, 0);
    // mask invalid actions
    auto legal_moves = game->get_legal_moves();

    if (this->neural_network != nullptr)
    {
      auto future = this->neural_network->commit(game.get());
      auto result = future.get();

      action_priors = std::move(result[0]);
      value = result[1][0];
    }
    else
    {
      // std::cout<< "not use neural_network!!!!" <<std::endl;
      double sum = std::accumulate(legal_moves.begin(), legal_moves.end(), 0);
      for (unsigned int i = 0; i < action_priors.size(); i++)
      {
        action_priors[i] = legal_moves[i] / sum;
      }
    }

    double sum = 0;
    for (unsigned int i = 0; i < action_priors.size(); i++)
    {
      if (legal_moves[i] == 1)
      {
        // action_priors[i] += 12*FLT_EPSILON;
        sum += action_priors[i];
      }
      else
      {
        action_priors[i] = 0;
      }
    }

    // renormalization
    if (sum > FLT_EPSILON)
    {
      std::for_each(action_priors.begin(), action_priors.end(),
                    [sum](double &x)
                    { x /= sum; });
    }
    else
    {
      // all masked

      // NB! All valid moves may be masked if either your NNet architecture is
      // insufficient or you've get overfitting or something else. If you have
      // got dozens or hundreds of these messages you should pay attention to
      // your NNet and/or training process.
      std::cout << "Check training process!! All valid moves were masked, do workaround." << std::endl;

      sum = std::accumulate(legal_moves.begin(), legal_moves.end(), 0);
      for (unsigned int i = 0; i < action_priors.size(); i++)
      {
        action_priors[i] = legal_moves[i] / sum;
      }
    }

    // expand
    node->expand(action_priors);
  }
  else
  {
    // end
    auto winner = status.second;
    value = (winner == 0 ? 0 : (winner == game->get_current_color() ? 1 : -1));
  }

  // value(parent -> node) = -value
  node->backup(-value);
  return;
}
